{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "PnWDOzZelHlE",
   "metadata": {
    "id": "PnWDOzZelHlE"
   },
   "source": [
    "# 🌱 Gemma 2B Instruction Agent\n",
    "\n",
    "**Goal:** Learn how to prepare data, train, run, and save a model using Google’s open-source `gemma-2-2b-it` model.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4qumJb5qs9CF",
   "metadata": {
    "id": "4qumJb5qs9CF"
   },
   "source": [
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhivyaBharathy-web/PraisonAI/blob/main/examples/cookbooks/Gemma2B_Instruction_Agent.ipynb)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qdsqiQDosXmA",
   "metadata": {
    "id": "qdsqiQDosXmA"
   },
   "source": [
    "#  Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6G4RPDe3lHlH",
   "metadata": {
    "id": "6G4RPDe3lHlH"
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# Gemma 2 checkpoints need a transformers release that includes the Gemma 2 architecture (4.42+).\n",
    "%pip install -q \"transformers>=4.42\" accelerate datasets bitsandbytes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "mop0KNO7sdEl",
   "metadata": {
    "id": "mop0KNO7sdEl"
   },
   "source": [
    "# Tools & Model Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "spOeRX2UlHlI",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 369,
     "referenced_widgets": [
      "a2db67e368d94f579c59b36c6ce7af33",
      "5816fa4288cd42d6b15094e2c79f760f",
      "5d6fe9b464f944cdbb00b9936669e737",
      "19920e5b8157423ba5d173dd376e52e6",
      "bd7157974bc5485ab5b6db54facb9721",
      "0063e7320c98440d91ae475596c5cd90",
      "a0f045bc54754d5a85ed6bf7b6b33c1a",
      "9174c5aff19b4caab8a38fbf3acb2449",
      "3de0f0b7ecb34c488145b24e30418e4b",
      "d4457543819d468eb916f126382184b0",
      "08f90c04f0de4b1db57624db4b43e17d",
      "269373e971704b9ea2e10e261b1c56d1",
      "d54d35c945ce4f75af6f83f37dd6af6c",
      "b828857fecdc40fa91f6a8633f0f20b4",
      "b5e70949aea24ae9a75e45aa291cc3aa",
      "9c556838b24246a8b39b1f3f9babfe4a",
      "062b64d17b7043c1bccc419280c9db0a",
      "138ae14327184f65896c29a424bad18d",
      "cb65fa660df54218ad804815dfba4071",
      "93d2454f1b564896bef50906e859c660",
      "64e7861e6afc4b57b28063096533aac8",
      "13fe0b8a8fd14a3684ae5d9beefb3f76",
      "d1ec61e51d4c4b07b301959c4919f2a2",
      "a45e9ed3d46645a18ca1f6b06f875ec9",
      "a364770555cf40fdab5ff7bb85997127",
      "12b3225aeb8c4f61b23b588738d48683",
      "dca4952bae4c49ddb2fec2c36726a34d",
      "163df8921f04425caed240465d8bd8b7",
      "1a09718a80cd4bb5a9e985a3f599c614",
      "71c719d57e4f4e23bd520f80e2e07e33",
      "5e8c8eb7eba948fc8a39b601e9ac658e",
      "db4fa2c97b7247c4826a1a9f69c1bed3",
      "42a6f9c136584cc98fc2be973d9edaec",
      "c1bc8cf2ed6542c49c1940673a662b1d",
      "365b69634e2846a09b91de6b61d0719b",
      "13eb69bc1c3c432db6c504856f3a0f7a",
      "56444404d97549848ed729ee8bc14ec8",
      "f0303454544d42a1ada2213a085a4ac0",
      "c717969ed1b145a9af67d74482d49eab",
      "9e9a826dfb264e44aa69c553d978729d",
      "f2b4ff56a96c4d7e9f56668f6b7b5213",
      "f71f0a3ccc654df288d7489abca0ffc0",
      "5ab0574ed1db47e9b6f43e27acd6c2b2",
      "426fae423dbb43ef9252e6fcfb099c6f",
      "2e1c3ca06b2948f0bcf53a975dd67aec",
      "3685a74a466a4ca0b5f05e7aee631d74",
      "7bda5a4e803c4cbc8da67da8a565725b",
      "5450fde9a54d450d90d14b009047b522",
      "e827839bf4d34d2193776d1ee90d27ad",
      "3d9854af016a4c3c8a582fc7f0f0b0bf",
      "7045a11d5f0f4b2d8ff224bb76a9c943",
      "32e568b25c2d447a8fa63055cb44c75c",
      "ebd9d55db82542e5ae6bb9e715e36a6a",
      "798a57883792465ab024dc19a58a54be",
      "c9d1133770b44b22bda89f3d0eecd4f8",
      "c040fa1bbd9f442884776fdf33efb4a2",
      "c57e497d4fd546ff950eb2c898b94c88",
      "5ae22935b7374219a16272cef31d6d9e",
      "aae2686ee64740218d046ec37e50b584",
      "87b4c953163a4cff8a88b76f40da9de8",
      "30ee0e766fae44bfb570ae8adfba1a54",
      "1c2db67dfc904e30a4013d9fe0253a4c",
      "f6c5380bba0f40a4bea9b78d205eb726",
      "0de9f765500b45428f828ad173dd279e",
      "0663e5cfc20c496993b7c0054ce3d1a1",
      "7ce5d369129b461681f424e52415e534",
      "dfa90d76c670436a80f1a4fd0971b171",
      "0694216654ac41909d567a8c4af247de",
      "658048f25b4a4ad6a3ee66e97a542cc9",
      "0edccfd243f940cfad9723b1d3f5a97c",
      "32d473040a6744718b1b35543f2b3a0f",
      "eadbfe846763427e98a40597560d8519",
      "3cec8231bccb4bcf8d657ee8d2120f79",
      "8bd43207fa8f4b309590f605ad81354f",
      "8821912860f04105959ceb42e53de948",
      "3666b044c757439f948f6942d7682c2f",
      "d47e7c7504374fee82eef9f8cbe24eff",
      "9c6d577f113f4f4baaa1f0e001549862",
      "b0c88d34b0df46b7a3da8db42ff3274e",
      "e4e4a2d0422b45c09c8691ce14ba9b9a",
      "d9b1de1b37f14d70bc56a5f6cb468c9b",
      "cb95b11368984236937fcec7cb689b04",
      "69e70bfdafd4459491bad75651028626",
      "a3b9f03c1f014691b5959fb9012c0bab",
      "088da12bcff64297ad82dc7b2d01709e",
      "8a69dc8aa84e49b1aabb58eddca1eee3",
      "263e458b8a5f405bb1ee365fe561309d",
      "aa3039e03447492da11f86185384470b",
      "3c22264e11c0402fb07f56a0763dc055",
      "db2c5746b8cc48d9a059bcef3b55217d",
      "ad3ab2ab81f6469485a3072bd1ba6a4a",
      "74ab1c79c43041c7ac314d62a6da04bf",
      "ad1a3719b1814f6b88202e131ae923c7",
      "d1cb72251682425fa6ca68341a084f95",
      "895914fca4204593b1a4314e50b8e5a1",
      "a8374dc19b294b6d8f5f92bee5196a1a",
      "b7bef910950949fbac65fd996ad058e2",
      "ce6b7aadbd904b3caa93d927ced6beb3",
      "ba6a9e31ddf7419db317a5d27a1567bf",
      "f271e9c78dcf4ab686caac37f18bd6f7",
      "9ed57fe81783476986d6cb98d07e9aa1",
      "8da75a35ab184675b0164682d8217406",
      "f411c4485b844ceb9b3116f209a6fc86",
      "60ac9cd6505141348323b39db7314554",
      "59b2d494f11344879463a3471716bc3f",
      "46abda4b52224874a93b729f5d4d195b",
      "37bc81b97d48474aa0fd7e7739654d86",
      "73da5101d99440abac30dc8c084439d4",
      "ef790742325743fdb8f7cc6822388185",
      "5eccfbe5715f4df9a96cffb2573f916f",
      "3bf293367470430f84f341ef8711460a",
      "317281ca2af54a148fa4cfd8f6e0f382",
      "d655ef37643448aa95253143f8552f58",
      "158e61376dee400e820f47c87d59bac5",
      "d1fbbe8010e340119bf556c75284e1c2",
      "eea1731926554e56a85867de15153f82",
      "766fb34d23b941b2a9e0de153f4cb031",
      "beb52b95db564105bbd8bc2b2d21c1ce",
      "2effd1c45bee4278bd9172310f58fed4",
      "91d6488876284748932b664953c02146",
      "9157d82cca3f45c5b5a666a3d35d6abe"
     ]
    },
    "id": "spOeRX2UlHlI",
    "outputId": "093b74d1-9ece-4db5-d177-3aff6e2d6e0b"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2db67e368d94f579c59b36c6ce7af33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "269373e971704b9ea2e10e261b1c56d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1ec61e51d4c4b07b301959c4919f2a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1bc8cf2ed6542c49c1940673a662b1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e1c3ca06b2948f0bcf53a975dd67aec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c040fa1bbd9f442884776fdf33efb4a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfa90d76c670436a80f1a4fd0971b171",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c6d577f113f4f4baaa1f0e001549862",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c22264e11c0402fb07f56a0763dc055",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f271e9c78dcf4ab686caac37f18bd6f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3bf293367470430f84f341ef8711460a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer\n",
    "\n",
    "# Never paste an access token into the notebook source: it would be saved with the file.\n",
    "# Take it from the HF_TOKEN environment variable, or prompt for it without echoing.\n",
    "hf_token = os.environ.get(\"HF_TOKEN\") or getpass(\"Hugging Face access token: \")\n",
    "login(token=hf_token)\n",
    "\n",
    "model_id = \"google/gemma-2-2b-it\"\n",
    "\n",
    "# Tokenizer and weights are both fetched from the same Hub repository.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    device_map=\"auto\",              # Automatically selects GPU if available\n",
    "    torch_dtype=torch.float16       # Load the weights in half precision\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "GaykKFKYspMd",
   "metadata": {
    "id": "GaykKFKYspMd"
   },
   "source": [
    "# Prompt Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "iLBZ7WPAlHlJ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "iLBZ7WPAlHlJ",
    "outputId": "a99fc3f2-68e0-4af8-cfbd-3c48a865538f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are Qwen, a helpful assistant.\n",
      "User: What is the capital of France?\n",
      "Assistant: The capital of France is **Paris**. \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Plain-text prompt (no chat template): a persona line followed by a single user turn.\n",
    "prompt = \"You are Gemma, a helpful assistant.\\nUser: What is the capital of France?\\nAssistant:\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "outputs = model.generate(**inputs, max_new_tokens=100)\n",
    "# outputs[0] contains the prompt tokens followed by the generated continuation,\n",
    "# so the decoded text starts with the prompt itself.\n",
    "response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(response)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "XPqtCExSsvhn",
   "metadata": {
    "id": "XPqtCExSsvhn"
   },
   "source": [
    "# Use a small sample dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ActfjFx0lHlK",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "f1028b5dcab4460c81b09802ca7ec89c",
      "ff33dcccaad945e9bf136d0f72e38d74",
      "eec3a11b0605435d8d1e55092f5e5b85",
      "d997b7bc993b4aa58cfafa7c59829ad3",
      "62e2dc508746419ba3c0405a7199c6dc",
      "2a7fc59019c749aa98d0e970a9b6b168",
      "78067f66c8544020aa262c0012aa019b",
      "f5e511b4279747c0a33b10f499a4f718",
      "10e248a5730c4c09834b4065b6dd7843",
      "bbd5a106c1de43e497987648899569e0",
      "118faf3d1ed142038af4783fb3ba86cc"
     ]
    },
    "id": "ActfjFx0lHlK",
    "outputId": "f257e553-7baa-4f1e-b531-a6b7dc358665"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f1028b5dcab4460c81b09802ca7ec89c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/4 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "\n",
    "def tokenize_function(example):\n",
    "    \"\"\"Tokenize the ``text`` field of one record, padded/truncated to 64 tokens.\"\"\"\n",
    "    encoded = tokenizer(\n",
    "        example[\"text\"],\n",
    "        max_length=64,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "    return encoded\n",
    "\n",
    "\n",
    "# Four short factual sentences that serve as a toy corpus.\n",
    "sample_data = {\n",
    "    \"text\": [\n",
    "        \"The sun is a star at the center of our solar system.\",\n",
    "        \"Photosynthesis is the process by which green plants make food.\",\n",
    "        \"Water freezes at 0 degrees Celsius.\",\n",
    "        \"The Earth revolves around the sun in 365 days.\",\n",
    "    ]\n",
    "}\n",
    "dataset = Dataset.from_dict(sample_data)\n",
    "\n",
    "# Tokenize every row, then expose only the model inputs as torch tensors.\n",
    "tokenized_dataset = dataset.map(tokenize_function)\n",
    "tokenized_dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "x7icWpOnszjj",
   "metadata": {
    "id": "x7icWpOnszjj"
   },
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8mM4PAhRlHlL",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8mM4PAhRlHlL",
    "outputId": "1ba90639-b847-4cb1-df42-3fdd644e4cb8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "user\n",
      "Explain photosynthesis to a child.\n",
      "model\n",
      "Imagine plants are like tiny chefs, and they cook their own food! \n",
      "\n",
      "They use sunlight, air, and water to make yummy food called sugar.  \n",
      "\n",
      "Here's how it works:\n",
      "\n",
      "1. **Sunlight:** Plants have special green stuff called chlorophyll that acts like a solar panel, soaking up the sun's energy.\n",
      "2. **Air:** Plants take in air through tiny holes in their leaves called stomata.  The air has a gas called carbon dioxide.\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "input_text = \"Explain photosynthesis to a child.\"\n",
    "# Render a single user turn with the tokenizer's chat template; add_generation_prompt=True\n",
    "# appends the marker that opens the model's turn.\n",
    "chat = tokenizer.apply_chat_template(\n",
    "    [{\"role\": \"user\", \"content\": input_text}],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True\n",
    ")\n",
    "# The rendered template text already carries the special tokens it needs, so tokenize it\n",
    "# with add_special_tokens=False to avoid prepending a duplicate BOS token.\n",
    "inputs = tokenizer(chat, return_tensors=\"pt\", add_special_tokens=False).to(model.device)\n",
    "output = model.generate(**inputs, max_new_tokens=100)\n",
    "print(tokenizer.decode(output[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "Fhkfhbjns2oq",
   "metadata": {
    "id": "Fhkfhbjns2oq"
   },
   "source": [
    "# Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "OuJEt2-jlHlL",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OuJEt2-jlHlL",
    "outputId": "f076e950-cce7-4850-edd9-ff1d84a096d6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('gemma-finetuned-demo/tokenizer_config.json',\n",
       " 'gemma-finetuned-demo/special_tokens_map.json',\n",
       " 'gemma-finetuned-demo/chat_template.jinja',\n",
       " 'gemma-finetuned-demo/tokenizer.model',\n",
       " 'gemma-finetuned-demo/added_tokens.json',\n",
       " 'gemma-finetuned-demo/tokenizer.json')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write the model and the tokenizer files into the same directory.\n",
    "# NOTE(review): no training step appears in this notebook, so despite the directory\n",
    "# name these look like the weights as loaded above - confirm.\n",
    "save_dir = \"gemma-finetuned-demo\"\n",
    "model.save_pretrained(save_dir)\n",
    "tokenizer.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "uFK1sCzblHlM",
   "metadata": {
    "id": "uFK1sCzblHlM"
   },
   "source": [
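    "# Output\n",
    "Illustrative response to the photosynthesis prompt (the model's actual wording may differ):\n",
    "\n",
    "`Photosynthesis is how plants eat sunlight! 🌞 They use air, water, and sunlight to make food and grow.`"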
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
